{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7sNELNzFVhIV"
   },
   "source": [
    "![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)\n",
    "\n",
    "This notebook provides a tutorial for using Predictive Sampling with MJX."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "l9lUn3zmWIwP"
   },
   "source": [
    "### Copyright notice"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7WpsAFhAWjmJ"
   },
   "source": [
    "Copyright 2024 DeepMind Technologies Limited.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aKSPxKQSWm2Z"
   },
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "D7Vucuuug6W5"
   },
   "outputs": [],
   "source": [
    "!pip install mujoco\n",
    "!pip install brax\n",
    "\n",
    "# Set up GPU rendering.\n",
    "from google.colab import files\n",
    "import distutils.util\n",
    "import os\n",
    "import subprocess\n",
    "if subprocess.run('nvidia-smi').returncode:\n",
    "  raise RuntimeError(\n",
    "      'Cannot communicate with GPU. '\n",
    "      'Make sure you are using a GPU Colab runtime. '\n",
    "      'Go to the Runtime menu and select Choose runtime type.')\n",
    "\n",
    "# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
    "# This is usually installed as part of an Nvidia driver package, but the Colab\n",
    "# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
    "# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
    "NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
    "if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
    "  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n",
    "    f.write(\"\"\"{\n",
    "    \"file_format_version\" : \"1.0.0\",\n",
    "    \"ICD\" : {\n",
    "        \"library_path\" : \"libEGL_nvidia.so.0\"\n",
    "    }\n",
    "}\n",
    "\"\"\")\n",
    "\n",
    "# Configure MuJoCo to use the EGL rendering backend (requires GPU)\n",
    "print('Setting environment variable to use GPU rendering:')\n",
    "%env MUJOCO_GL=egl\n",
    "\n",
    "# Check if installation was succesful.\n",
    "try:\n",
    "  print('Checking that the installation succeeded:')\n",
    "  import mujoco\n",
    "  mujoco.MjModel.from_xml_string('<mujoco/>')\n",
    "except Exception as e:\n",
    "  raise e from RuntimeError(\n",
    "      'Something went wrong during installation. Check the shell output above '\n",
    "      'for more information.\\n'\n",
    "      'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
    "      'by going to the Runtime menu and selecting \"Choose runtime type\".')\n",
    "\n",
    "print('Installation successful.')\n",
    "\n",
    "from IPython.display import clear_output\n",
    "clear_output()\n",
    "print(\"mujoco_mpc.mjx isn't installed yet -- please make sure you install it and fetch the task files from the MuJoCo Menagerie GitHub repository.\")\n",
    "from brax import base as brax_base\n",
    "from brax.io import html\n",
    "from brax.io import mjcf\n",
    "from IPython.display import HTML\n",
    "import jax\n",
    "import matplotlib.pyplot as plt\n",
    "from mujoco import mjx\n",
    "from mujoco_mpc.mjx import predictive_sampling\n",
    "from mujoco_mpc.mjx.tasks import insert\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qDbIhqEP5DtK"
   },
   "source": [
    "# Load task, run optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "RIB7ywe9g6W6"
   },
   "outputs": [],
   "source": [
    "sim_model_cpu, plan_model_cpu, cost_fn, instruction_fn = insert.get_models_and_cost_fn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "XcGmBj61g6W6"
   },
   "outputs": [],
   "source": [
    "costs_to_compare = {}\n",
    "for it in [0.8]:\n",
    "  nsteps = 2000\n",
    "  steps_per_plan = 10\n",
    "  batch_size = 1024\n",
    "  nsamples = 256\n",
    "  nplans = batch_size // nsamples\n",
    "  print(f'nplans: {nplans}')\n",
    "\n",
    "  p = predictive_sampling.Planner(\n",
    "      model=mjx.put_model(plan_model_cpu),\n",
    "      cost=cost_fn,\n",
    "      noise_scale=it,  # iterate on different values\n",
    "      horizon=128,\n",
    "      nspline=4,\n",
    "      nsample=nsamples - 1,\n",
    "      interp='zero',\n",
    "      instruction_fn=instruction_fn,\n",
    "  )\n",
    "\n",
    "  sim_data = mujoco.MjData(sim_model_cpu)\n",
    "  mujoco.mj_resetDataKeyframe(sim_model_cpu, sim_data, 0)\n",
    "  # without kinematics, the first cost is off:\n",
    "  mujoco.mj_forward(sim_model_cpu, sim_data)\n",
    "  sim_data = mjx.put_data(sim_model_cpu, sim_data)\n",
    "  q0s = np.tile(sim_data.qpos, (nplans, 1))\n",
    "  def set_qpos(data, qpos):\n",
    "    return data.replace(qpos=qpos)\n",
    "  sim_datas = jax.vmap(set_qpos, in_axes=(None, 0))(sim_data, q0s)\n",
    "  multi_policy = np.tile(sim_model_cpu.key_ctrl[0, :], (nplans, p.nspline, 1))\n",
    "  mpc_rollout_multiplan = jax.vmap(\n",
    "      predictive_sampling.mpc_rollout, in_axes=(\n",
    "          None,  # nsteps\n",
    "          None,  # steps_per_plan\n",
    "          None,  # Planner\n",
    "          0,     # init_policy\n",
    "          0,     # rng\n",
    "          None,  # sim_model\n",
    "          0,     # sim_data\n",
    "      )\n",
    "  )\n",
    "\n",
    "  sim_datas, final_policy, costs, trajectories, terms = jax.jit(\n",
    "      mpc_rollout_multiplan, static_argnums=[0, 1]\n",
    "  )(\n",
    "      nsteps,\n",
    "      steps_per_plan,\n",
    "      p,\n",
    "      jax.device_put(multi_policy),\n",
    "      jax.random.split(jax.random.key(0), nplans),\n",
    "      mjx.put_model(sim_model_cpu),\n",
    "      sim_datas,\n",
    "  )\n",
    "  costs = np.sum(terms.reshape(nplans, -1, terms.shape[-1])[:, :, 2:13:2], axis=-1)\n",
    "  costs_to_compare[it] = costs\n",
    "\n",
    "  plt.figure()\n",
    "  plt.xlim([0, nsteps * sim_model_cpu.opt.timestep])\n",
    "  plt.ylim([0, max(costs.flatten())])\n",
    "  plt.xlabel('time')\n",
    "  plt.ylabel('cost')\n",
    "  x_time = [i * sim_model_cpu.opt.timestep for i in range(nsteps)]\n",
    "  for i in range(nplans):\n",
    "    plt.plot(x_time, costs[i, :], alpha=0.1)\n",
    "  avg = np.mean(costs, axis=0)\n",
    "  plt.plot(x_time, avg, linewidth=2.0)\n",
    "  var = np.var(costs, axis=0)\n",
    "  plt.errorbar(\n",
    "      x_time,\n",
    "      avg,\n",
    "      yerr=var,\n",
    "      fmt='none',\n",
    "      ecolor='b',\n",
    "      elinewidth=1,\n",
    "      alpha=0.2,\n",
    "      capsize=0,\n",
    "  )\n",
    "\n",
    "  plt.show()\n",
    "  costs_to_compare[it] = costs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "5ic6_2F8g6W6"
   },
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.xlim([0, nsteps * sim_model_cpu.opt.timestep])\n",
    "plt.ylim([0, max(costs.flatten())])\n",
    "plt.xlabel('time')\n",
    "plt.ylabel('cost')\n",
    "x_time = [i * sim_model_cpu.opt.timestep for i in range(nsteps)]\n",
    "for val, costs in costs_to_compare.items():\n",
    "  avg = np.mean(costs, axis=0)\n",
    "  plt.plot(x_time, avg, label=str(val))\n",
    "  var = np.var(costs, axis=0)\n",
    "  plt.errorbar(\n",
    "      x_time, avg, yerr=var, fmt='none', elinewidth=1, alpha=0.2, capsize=0\n",
    "  )\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "GkfgJQohg6W6"
   },
   "outputs": [],
   "source": [
    "d = mujoco.MjData(sim_model_cpu)\n",
    "sys = mjcf.load_model(sim_model_cpu)\n",
    "xstates = []\n",
    "for qpos in trajectories.q[2].reshape(-1, sim_model_cpu.nq):\n",
    "  d.qpos = qpos\n",
    "  mujoco.mj_kinematics(sim_model_cpu, d)\n",
    "  x = brax_base.Transform(pos=d.xpos[1:].copy(), rot=d.xquat[1:].copy())\n",
    "  xstates.append(brax_base.State(q=None, qd=None, x=x, xd=None, contact=None))\n",
    "\n",
    "HTML(html.render(sys, xstates))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "7lHL2pxsShyH"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "l9lUn3zmWIwP",
    "Ir0HGCL_4-qw"
   ],
   "private_outputs": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
